
Conversation

@fatemetkl (Collaborator) commented Jan 8, 2026

PR Type

[Fix | Documentation]

Short Description

Clickup Ticket(s): https://app.clickup.com/t/868gy123e

This PR introduces several improvements and fixes to the Ensemble Attack code, addressing issues we found during experimentation.

  1. RMIA shadow training data: The dataset size used for RMIA shadow training significantly affects results, so a new config option, attack_rmia_shadow_training_data_choice, controls which data the shadow models are trained on.
  2. Population data: DOMIAS requires a large population, so we now merge the population collected by the original attack with the experiment's data for more consistent results.
  3. Reduced memory usage by saving only the synthetic shadow model data instead of full TrainingResult objects.
  4. Multiprocessing and batching for the Gower matrix computation, to balance runtime and memory based on available resources (see the sketch at the end of this description).
  5. During the test phase, we load and reuse existing RMIA shadow models if they have already been trained and saved for the current experiment.

Several other minor fixes and improvements to the documentation are also included in this PR.
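As a rough illustration of item 4, here is a minimal sketch of the batched Gower computation, assuming the third-party gower package; the function name and default batch size below are illustrative, not the toolkit's actual API (the real implementation lives in rmia_calculation.py).

```python
import gower  # third-party package providing gower.gower_matrix
import numpy as np
import pandas as pd


def gower_batched(df_x: pd.DataFrame, df_y: pd.DataFrame, batch_size: int = 5000) -> np.ndarray:
    """Compute the full Gower distance matrix in row batches to cap peak memory."""
    out = np.empty((len(df_x), len(df_y)), dtype=np.float32)  # pre-allocate the result
    for start in range(0, len(df_x), batch_size):
        batch = df_x.iloc[start : start + batch_size]
        out[start : start + len(batch)] = gower.gower_matrix(batch, df_y)
    return out
```

Only one batch of pairwise distances is materialized at a time, so peak memory scales with batch_size rather than with the full number of rows.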

Tests Added

Existing tests are updated.

```yaml
challenge_data_path: ${target_model.target_model_directory}/${target_model.target_model_name}/challenge_with_id.csv
challenge_label_path: ${target_model.target_model_directory}/${target_model.target_model_name}/challenge_label.csv

target_attack_artifact_dir: ${base_experiment_dir}/target_${target_model.target_model_id}_attack_artifacts/
```
@fatemetkl (author):
This directory was extra and can be removed.

```diff
@@ -1,34 +1,36 @@
 # Ensemble experiment configuration
```
@fatemetkl (author), Jan 9, 2026:

The current single config is hard to understand because it mixes many variables and data paths with unclear names inherited from the original attack code. Splitting it into multiple pipeline‑specific configs would improve clarity and maintainability, even if it adds some overhead. Alternatively, improving variable naming within one config could be helpful.
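For reference, a hypothetical Hydra layout for such a split, with each pipeline stage as a config group under configs/<group>/<name>.yaml; all group and file names below are made up for illustration:

```yaml
# configs/experiment_config.yaml: hypothetical top-level config composing
# pipeline-specific groups via Hydra's defaults list (names are illustrative).
defaults:
  - data_collection: midst_default
  - shadow_training: rmia_default
  - attack_evaluation: ensemble_default
  - _self_

base_experiment_dir: /path/to/experiments
```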

```diff
 )

-population.append(df_real)
 population.append(df_real)
```
@fatemetkl (author):
This was a bug! Thank you for catching this, Sara!

```diff
-# Load the required dataframes for shadow model training.
-# For shadow model training we need master_challenge_train and population data.
-# Master challenge is the main training (or fine-tuning) data for the shadow models.
-df_master_challenge_train = load_dataframe(
```
@fatemetkl (author):
Instead of loading the data here, it is now passed into the function by the caller; a sketch of the new call pattern is below.
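A fragment sketch of the caller side, assuming the Hydra config object is in scope and using the existing load_dataframe helper; the config key name is illustrative:

```python
# Fragment sketch (e.g., in run_attack.py): load the dataframe once at the call
# site, then pass it into the training entry point. Key names are illustrative.
from midst_toolkit.attacks.ensemble.data_utils import load_dataframe

df_master_challenge_train = load_dataframe(config.data_paths.master_challenge_train_path)
shadow_model_paths = run_shadow_model_training(config, df_challenge_train=df_master_challenge_train)
```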

f"Fine-tuned model {model_id} generated {len(train_result.synthetic_data)} synthetic samples.",
)
attack_data["fine_tuned_results"].append(train_result)
attack_data["fine_tuned_results"].append(train_result.synthetic_data)
@fatemetkl (author):
We only need to save the synthetic data.

```diff
-target_shadow_models_output_path: ${target_model.target_attack_artifact_dir}/tabddpm_${target_model.target_model_id}_shadows_dir
+target_shadow_models_output_path: ${base_experiment_dir}/test_all_targets # Sub-directory to store test shadows and results
 attack_probabilities_result_path: ${target_model.target_shadow_models_output_path}/test_probabilities/attack_model_${target_model.target_model_id}_proba
+attack_rmia_shadow_training_data_choice: "combined" # Options: "combined", "only_challenge", "only_train". This determines which data to use for training RMIA attack model in testing phase.
```
@fatemetkl (author):
This is a new config variable. You can read more about the options in select_challenge_data_for_training()'s docstring.
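A rough sketch of the selection logic behind the three options; the real behavior is documented in the docstring, and the signature below is illustrative:

```python
import pandas as pd


def select_challenge_data_for_training(
    choice: str, df_challenge: pd.DataFrame, df_train: pd.DataFrame
) -> pd.DataFrame:
    """Illustrative only: pick the RMIA shadow training data per the config option."""
    if choice == "combined":
        return pd.concat([df_challenge, df_train], ignore_index=True)
    if choice == "only_challenge":
        return df_challenge
    if choice == "only_train":
        return df_train
    raise ValueError(f"Unknown choice '{choice}'; expected 'combined', 'only_challenge', or 'only_train'.")
```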



```python
@hydra.main(config_path="configs", config_name="experiment_config", version_base=None)
def run_metaclassifier_testing(
```
@fatemetkl (author), Jan 12, 2026:
This function executes the attack on a single target model (target_model.target_model_id). However, all target models within an experiment can share the same trained RMIA shadow models. Our current workflow is to train the RMIA shadow models using one target model, and then run tests on all remaining targets in parallel using run_test.sh. Each of these targets simply loads the previously trained RMIA shadow models. This approach was originally designed to speed up testing.

Later, we realized that the main runtime bottleneck of the testing phase is actually the RMIA shadow-model training step. As a result, a potential refactoring improvement would be to modify this function so that it trains the RMIA shadow models once and then sequentially tests a set of target models within a single call, as sketched below. This would simplify the testing process with potentially little to no loss in efficiency.
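An outline of that refactor, with hypothetical helper names:

```python
# Hypothetical sketch: train (or load) the RMIA shadow models once, then test
# every target sequentially in the same call. Both helpers below are made up.
def run_metaclassifier_testing_all_targets(config) -> None:
    shadow_model_paths = train_or_load_rmia_shadow_models(config)  # expensive step, done once
    for target_id in config.attack_success_computation.target_ids_to_test:
        run_single_target_test(config, target_id, shadow_model_paths)
```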

@fatemetkl marked this pull request as ready for review January 12, 2026 17:48
@coderabbitai bot commented Jan 12, 2026

📝 Walkthrough

This pull request moves the ensemble attack pipeline from the 20k to the 10k data configuration. It restructures the data collection workflow to use pre-loaded population and challenge datasets passed as parameters rather than loaded internally, and optimizes distance computations through batched, multiprocessing-enabled Gower calculations. The changes include configuration updates with expanded data splits, removal of on-disk data loading in favor of externally managed DataFrames, refactoring of internal data representations to store synthetic data directly instead of TrainingResult objects, and comprehensive test updates reflecting the new data handling approach.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Title check | ❓ Inconclusive | The title 'Ft/ensemble changes' is vague and generic, failing to clearly convey the specific improvements and fixes included in the changeset. | Use a more descriptive title that highlights the primary changes, such as 'Add multiprocessing to RMIA Gower computation and improve population data handling' or 'Optimize ensemble attack memory usage and add RMIA shadow model reuse'. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description check | ✅ Passed | The PR description follows the required template with PR Type, Short Description, Clickup Ticket link, detailed explanation of changes, and a note about tests being updated. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 91.30%, which is sufficient; the required threshold is 80.00%. |


@coderabbitai bot left a comment

Actionable comments posted: 7

🤖 Fix all issues with AI agents
In @examples/ensemble_attack/configs/experiment_config.yaml:
- Around lines 109-110: In the attack_success_computation block, the target_ids_to_test list contains a duplicate 26 and is missing 27. Replace the duplicated 26 with 27 so that each target ID appears exactly once and 27 is included.
- Around lines 2-4: Fix the typo in the top comment of the YAML config: the comment currently reading "run_attack.py and tets_attack_model.py" should read "run_attack.py and test_attack_model.py".

In @examples/ensemble_attack/real_data_collection.py:
- Around lines 182-195: Remove the duplicated block that repeats the population_splits/challenge_splits defaults and the redundant save_dir.mkdir call. Keep the first occurrence that sets population_splits = ["train"] and challenge_splits = ["train", "dev", "final"], and delete the second "if population_splits is None / if challenge_splits is None" block and the extra save_dir.mkdir, so that only one mkdir(save_dir) call and one defaults assignment remain and callers relying on those variables still see the intended defaults.

In @examples/ensemble_attack/run_metaclassifier_training.py:
- Around lines 89-92: The log call that prints the reference population path contains a malformed f-string ("f{config.data_paths.population_path}"), so the literal text "f{...}" is logged instead of the path. Remove the stray leading "f" before the curly brace so the f-string interpolates the actual value of config.data_paths.population_path.
- Around lines 28-29: The docstring duplicates the parameter description for target_model_synthetic_path. Remove the redundant entry so target_model_synthetic_path is described only once.

In @examples/ensemble_attack/run_shadow_model_training.py:
- Around lines 106-110: The code contains a duplicate assertion checking "trans_id" in df_challenge_train.columns. Delete the repeated occurrence so df_challenge_train is asserted only once, keeping the existing assertion for df_population_with_challenge.

In @examples/ensemble_attack/test_attack_model.py:
- Around lines 96-97: The shadow_model_paths returned by run_shadow_model_training(...) is immediately discarded by the overwrite from config.shadow_training.final_shadow_models_path. Delete the second assignment (shadow_model_paths = [Path(path) for path in config.shadow_training.final_shadow_models_path]) so downstream logic uses the paths returned by run_shadow_model_training. Alternatively, if the config paths are truly intended, remove the run_shadow_model_training call instead, but prefer keeping its return value.
🧹 Nitpick comments (3)
src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (1)

350-388: Consider making n_jobs configurable.

The n_jobs=4 is hardcoded in multiple calls to get_rmia_gower. Consider passing this as a parameter to calculate_rmia_signals for flexibility across different hardware configurations.
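One possible shape for that change, with the surrounding parameters elided and the signature illustrative only:

```python
# Sketch: accept n_jobs instead of hardcoding it, then forward it to each
# get_rmia_gower call (the module's existing helper; argument shapes illustrative).
def calculate_rmia_signals(model_data, n_jobs: int = 4):
    return [get_rmia_gower(df, n_jobs=n_jobs) for df in model_data]
```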

tests/integration/attacks/ensemble/test_shadow_model_training.py (1)

68-71: Consider reordering assertions for clearer error messages.

If synthetic_data were None, the type assertion would fail with a confusing message. Consider checking for None first.

Suggested order:

```diff
     for synthetic_data in shadow_data["fine_tuned_results"]:
-        assert type(synthetic_data) is pd.DataFrame
         assert synthetic_data is not None
+        assert type(synthetic_data) is pd.DataFrame
         assert len(synthetic_data) == 5
```
examples/ensemble_attack/test_attack_model.py (1)

162-168: Consider making the data split configurable.

The hardcoded data_splits=["test"] with the comment suggesting manual changes for different experiments could be error-prone. Consider extracting this to a config parameter.
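For instance, the split could come from config instead (key name and location hypothetical):

```yaml
# Hypothetical addition to experiment_config.yaml replacing the hardcoded
# data_splits=["test"] in test_attack_model.py.
data_collection:
  test_data_splits: ["test"]
```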

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cfc4306 and 951abfb.

📒 Files selected for processing (11)
  • examples/ensemble_attack/configs/experiment_config.yaml
  • examples/ensemble_attack/real_data_collection.py
  • examples/ensemble_attack/run_attack.py
  • examples/ensemble_attack/run_metaclassifier_training.py
  • examples/ensemble_attack/run_shadow_model_training.py
  • examples/ensemble_attack/run_train.sh
  • examples/ensemble_attack/test_attack_model.py
  • src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py
  • src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py
  • tests/integration/attacks/ensemble/test_shadow_model_training.py
  • tests/unit/attacks/ensemble/test_rmia.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-11T16:08:49.024Z
Learnt from: lotif
Repo: VectorInstitute/midst-toolkit PR: 107
File: examples/gan/synthesize.py:1-47
Timestamp: 2025-12-11T16:08:49.024Z
Learning: When using SDV (version >= 1.18.0), prefer loading a saved CTGANSynthesizer with CTGANSynthesizer.load(filepath) instead of sdv.utils.load_synthesizer(). This applies to Python code across the repo (e.g., any script that loads a CTGANSynthesizer). Ensure the SDV version is >= 1.18.0 before using CTGANSynthesizer.load, and fall back to sdv.utils.load_synthesizer() only if a compatible alternative is required.

Applied to files:

  • src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py
  • examples/ensemble_attack/run_shadow_model_training.py
  • examples/ensemble_attack/real_data_collection.py
  • tests/integration/attacks/ensemble/test_shadow_model_training.py
  • tests/unit/attacks/ensemble/test_rmia.py
  • examples/ensemble_attack/run_metaclassifier_training.py
  • src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py
  • examples/ensemble_attack/test_attack_model.py
  • examples/ensemble_attack/run_attack.py
🧬 Code graph analysis (5)
src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (2)
  • fine_tune_tabddpm_and_synthesize (158-248)
  • TrainingResult (26-33)
examples/ensemble_attack/run_shadow_model_training.py (1)
src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1)
  • train_three_sets_of_shadow_models (309-442)
tests/unit/attacks/ensemble/test_rmia.py (1)
src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (2)
  • Key (23-25)
  • get_rmia_gower (136-215)
examples/ensemble_attack/test_attack_model.py (3)
examples/ensemble_attack/real_data_collection.py (2)
  • AttackType (17-31)
  • collect_midst_data (101-142)
src/midst_toolkit/attacks/ensemble/blending.py (1)
  • MetaClassifierType (21-23)
src/midst_toolkit/attacks/ensemble/data_utils.py (1)
  • load_dataframe (31-52)
examples/ensemble_attack/run_attack.py (3)
src/midst_toolkit/attacks/ensemble/data_utils.py (1)
  • load_dataframe (31-52)
examples/ensemble_attack/real_data_collection.py (1)
  • collect_population_data_ensemble (145-256)
examples/ensemble_attack/run_shadow_model_training.py (1)
  • run_shadow_model_training (83-134)
🪛 Ruff (0.14.10)
examples/ensemble_attack/test_attack_model.py

135-135: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)


224-226: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (44)
examples/ensemble_attack/run_metaclassifier_training.py (2)

68-68: LGTM!

Good addition of logging after loading shadow model data to aid debugging and traceability.


77-80: LGTM!

Helpful logging addition for tracing data loading with size information.

examples/ensemble_attack/real_data_collection.py (7)

7-7: LGTM!

Appropriate imports added for logging functionality.

Also applies to: 14-14


64-66: LGTM!

Documentation improvements clarify the meaning of data_split parameter.

Also applies to: 82-88


110-121: LGTM!

Enhanced documentation for data_splits parameter improves clarity.


140-142: LGTM!

Simplified append logic looks correct.


149-149: LGTM!

Updated function signature and comprehensive docstring for the new original_repo_population parameter.

Also applies to: 156-174


198-213: LGTM!

Good addition of logging for population data collection and the concatenation with original_repo_population.


231-231: LGTM!

Helpful logging for challenge data collection with splits information.

src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (6)

7-11: LGTM!

Good additions: Sequence for type hints, Pool for multiprocessing, and FloatDType type alias for clearer typing.

Also applies to: 20-21


28-70: LGTM!

Well-implemented batched Gower distance computation. Pre-allocating the output matrix and processing in chunks is an effective pattern for reducing peak memory usage. The batch size of 5000 used at the call site (line 131) is reasonable.


73-133: LGTM!

Good implementation of a multiprocessing-friendly wrapper. Creating a copy of df_synthetic (line 118) before modifications is correct to avoid mutation issues. The tuple-based argument passing is appropriate for Pool.imap_unordered.


136-215: LGTM!

Solid multiprocessing implementation. Using imap_unordered with index tracking in a dict, then reconstructing original order is the correct pattern for parallel processing where order matters. The fallback to sequential processing when use_multiprocessing=False is useful for debugging and testing.
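For readers unfamiliar with it, the pattern in minimal, generic form (compute and tasks are placeholders, not toolkit code):

```python
from multiprocessing import Pool


def compute(payload):
    # Placeholder for the real per-task work (e.g., one Gower computation).
    return payload * 2


def _worker(args):
    idx, payload = args
    return idx, compute(payload)


def run_in_order(tasks, n_jobs=4):
    # imap_unordered yields results as workers finish; track each task's index
    # and rebuild the original submission order afterwards.
    with Pool(n_jobs) as pool:
        by_index = dict(pool.imap_unordered(_worker, list(enumerate(tasks))))
    return [by_index[i] for i in range(len(tasks))]


if __name__ == "__main__":
    print(run_in_order([1, 2, 3]))  # [2, 4, 6]
```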


334-340: LGTM!

Updated to work with DataFrames directly instead of TrainingResult objects, aligning with the refactored data storage approach.


394-396: LGTM!

Logging additions and updates to use shadow_training_data_ids for mask creation are consistent with the refactored data handling approach.

Also applies to: 406-406, 422-427, 455-455

src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (3)

187-187: LGTM!

Good refactor to store only synthetic_data (DataFrame) instead of the full TrainingResult object. This reduces memory usage and pickle file sizes. The assertion on line 182 ensures the data is not None before appending.


299-299: LGTM!

Consistent with the change in train_fine_tuned_shadow_models. The assertion on line 293 validates the data before appending.


441-442: LGTM!

Trailing whitespace change has no functional impact.

examples/ensemble_attack/run_train.sh (3)

6-12: LGTM, but verify GPU availability.

Resource increases align with the ensemble attack requirements. The specific gpu:a100:1 request may fail if A100 GPUs are unavailable on the cluster. Consider using a more generic GPU request or documenting the A100 requirement.


15-15: LGTM!

Useful memory logging for debugging resource allocation.


24-24: LGTM!

Config name updated to target the 10k data experiment configuration.

tests/integration/attacks/ensemble/test_shadow_model_training.py (1)

107-109: LGTM!

Test assertions correctly updated to validate DataFrame type and expected length. Note: this test doesn't check for None, but the production code assertion at line 293 in shadow_model_training.py ensures this won't happen.

examples/ensemble_attack/run_attack.py (3)

14-14: LGTM!

Import added to support the new data loading functionality.


27-34: LGTM!

Good addition of loading the original repository population data and passing it to collect_population_data_ensemble. The comment clearly explains why this is needed (to provide a larger population dataset for DOMIAS).

Also applies to: 41-41


81-85: LGTM!

Correctly loads the master challenge training data and passes it to the updated run_shadow_model_training function signature.

examples/ensemble_attack/run_shadow_model_training.py (3)

5-5: LGTM!

The added pandas import is necessary to support the new DataFrame type hint in the function signature.


83-95: LGTM!

Good refactor to accept df_challenge_train as a parameter instead of loading from disk. This aligns with the broader PR goal of passing DataFrames directly rather than loading internally, improving testability and flexibility.


114-128: LGTM!

The updated call to train_three_sets_of_shadow_models correctly passes df_challenge_train as master_challenge_data, which aligns with the function's signature in shadow_model_training.py.

tests/unit/attacks/ensemble/test_rmia.py (6)

45-53: LGTM!

The model_data fixture correctly stores DataFrames directly instead of wrapping them in mock objects, aligning with the refactored API that expects list[pd.DataFrame] for model_data.


77-92: LGTM!

The rmia_signal_data fixture is correctly updated to store DataFrames directly in fine_tuned_results and trained_results lists, consistent with the new data model.


151-169: LGTM!

Good test updates:

  • Using list(...) to extract DataFrames from model_data
  • use_multiprocessing=False ensures mocks work correctly in the main process
  • dtype=np.float32 in expected arrays matches the function's default dtype

173-179: LGTM!

Correctly accessing DataFrames directly from model_data instead of through .synthetic_data attribute.


181-216: LGTM!

The sampling test is well-updated with:

  • Direct DataFrame access
  • use_multiprocessing=False for deterministic behavior
  • Enhanced assert_frame_equal with descriptive obj parameter for better debugging

218-237: LGTM!

The missing categorical column test correctly uses list(...) to extract DataFrames from the fixture.

examples/ensemble_attack/configs/experiment_config.yaml (2)

20-22: LGTM!

Good addition of attack_rmia_shadow_training_data_choice option with clear options documented in the comment. This provides flexibility for controlling RMIA shadow training dataset selection.


48-55: LGTM!

Good expansion of challenge_splits and folder_ranges to support the test phase data collection. The ranges are clearly structured.

examples/ensemble_attack/test_attack_model.py (8)

22-45: LGTM!

Good extraction of result saving logic into a dedicated helper function. The function handles both saving probabilities and optionally saving the TPR@FPR=0.1 score.


47-77: LGTM!

Clean helper function for extracting and dropping ID columns. Good use of assertions for validation.


114-142: LGTM!

Good implementation of load_trained_rmia_shadows_for_test_phase. The function correctly checks existence of all models before loading and returns early with an empty list if any model is missing.

Regarding the static analysis hint about pickle (S301): this is internal research tooling loading models from known paths, so the security risk is acceptable in this context.
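Roughly, the load-or-skip behavior described above (a sketch, not the function's actual signature):

```python
import pickle
from pathlib import Path


def load_trained_rmia_shadows_for_test_phase(model_paths: list[Path]) -> list:
    """Illustrative only: return [] unless every expected shadow model exists."""
    if not all(path.exists() for path in model_paths):
        return []  # caller falls back to training the shadows from scratch
    loaded = []
    for path in model_paths:
        with open(path, "rb") as f:
            loaded.append(pickle.load(f))
    return loaded
```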


145-184: LGTM!

Well-structured helper function for collecting challenge and train data with clear logging.


187-228: LGTM!

Good implementation of select_challenge_data_for_training with clear documentation of the three options. The ValueError for invalid choices provides a helpful error message.

Regarding the static analysis hint (TRY003): the detailed error message is appropriate here as it helps users understand the valid options.


311-315: LGTM!

Good defensive handling to limit synthetic data size based on config. Using .head() preserves consistency.


325-334: LGTM!

Good optimization to reuse existing shadow models when available, avoiding redundant training.


353-357: LGTM!

Good addition of loading reference population data for DOMIAS signals computation.
